package storm.trident;

import storm.trident.fluent.ChainedAggregatorDeclarer;
import storm.trident.fluent.GlobalAggregationScheme;
import storm.trident.fluent.GroupedStream;
import storm.trident.fluent.IAggregatableStream;
import storm.trident.operation.*;
import storm.trident.operation.impl.*;
import storm.trident.operation.impl.SingleEmitAggregator.BatchToPartition;
import storm.trident.partition.GlobalGrouping;
import storm.trident.partition.IdentityGrouping;
import storm.trident.partition.IndexHashGrouping;
import storm.trident.planner.Node;
import storm.trident.planner.NodeStateInfo;
import storm.trident.planner.PartitionNode;
import storm.trident.planner.ProcessorNode;
import storm.trident.planner.processor.*;
import storm.trident.state.QueryFunction;
import storm.trident.state.StateFactory;
import storm.trident.state.StateSpec;
import storm.trident.state.StateUpdater;
import storm.trident.util.TridentUtils;
import backtype.storm.generated.Grouping;
import backtype.storm.generated.NullStruct;
import backtype.storm.grouping.CustomStreamGrouping;
import backtype.storm.tuple.Fields;
import backtype.storm.utils.Utils;


// TODO: need to be able to replace existing fields with the function fields (like Cascading Fields.REPLACE)
/**
 * Fluent handle on a single node of a {@link TridentTopology} that is being built.
 *
 * <p>Most operations create a new node sourced from this stream through
 * {@code _topology.addSourcedNode(...)} (or {@code addSourcedStateNode(...)} for persistence)
 * and return the result; the node behind this handle is not replaced. Exceptions:
 * {@link #parallelismHint(int)} mutates the current node in place, while {@link #name(String)}
 * and {@link #toStream()} create no node at all.
 *
 * <p>Operations that accept input fields validate them against {@link #getOutputFields()} and
 * throw {@link IllegalArgumentException} for unknown fields; passing {@code null} input fields
 * skips that validation.
 */
public class Stream implements IAggregatableStream {
    /** The topology node this stream reads from. */
    Node _node;
    /** The topology that new nodes are added to. */
    TridentTopology _topology;
    /** Name handed to every node this class creates from the stream. */
    String _name;

    /**
     * Creates a handle on {@code node} within {@code topology}.
     *
     * @param topology the topology being built
     * @param name name passed to nodes derived from this stream
     * @param node the node this stream reads from
     */
    protected Stream(TridentTopology topology, String name, Node node) {
        _topology = topology;
        _node = node;
        _name = name;
    }

    /**
     * Returns a new handle on the same node and topology that uses {@code name} for nodes
     * derived from it. No node is added and this handle is unchanged.
     *
     * @param name the name to use for subsequently created nodes
     * @return a renamed handle on the same node
     */
    public Stream name(String name) {
        return new Stream(_topology, name, _node);
    }

    /**
     * Sets the parallelism hint on the node behind this stream. Unlike most operations this
     * mutates the current node rather than adding a new one.
     *
     * @param hint the parallelism hint to store on the node
     * @return this stream
     */
    public Stream parallelismHint(int hint) {
        _node.parallelismHint = hint;
        return this;
    }

    /**
     * Keeps only {@code keepFields}; the resulting stream's output fields are exactly
     * {@code keepFields}.
     *
     * @param keepFields fields to retain; each must be an output field of this stream
     * @return the projected stream
     * @throws IllegalArgumentException if a field is not an output field of this stream
     */
    public Stream project(Fields keepFields) {
        projectionValidation(keepFields);
        return _topology.addSourcedNode(this, new ProcessorNode(_topology.getUniqueStreamId(), _name,
            keepFields, new Fields(), new ProjectedProcessor(keepFields)));
    }

    /**
     * Groups this stream by {@code fields} for subsequent grouped operations.
     *
     * @param fields grouping fields; each must be an output field of this stream
     * @return a grouped view of this stream
     * @throws IllegalArgumentException if a field is not an output field of this stream
     */
    public GroupedStream groupBy(Fields fields) {
        projectionValidation(fields);
        return new GroupedStream(this, fields);
    }

    /**
     * Repartitions using Storm's fields grouping on {@code fields}.
     *
     * @param fields partitioning fields; each must be an output field of this stream
     * @return the repartitioned stream
     * @throws IllegalArgumentException if a field is not an output field of this stream
     */
    public Stream partitionBy(Fields fields) {
        projectionValidation(fields);
        return partition(Grouping.fields(fields.toList()));
    }

    /**
     * Repartitions using a custom grouping, which is serialized with {@code Utils.serialize}
     * into a {@code custom_serialized} grouping.
     *
     * @param partitioner the custom grouping to apply
     * @return the repartitioned stream
     */
    public Stream partition(CustomStreamGrouping partitioner) {
        return partition(Grouping.custom_serialized(Utils.serialize(partitioner)));
    }

    /**
     * Repartitions using Storm's shuffle grouping.
     *
     * @return the repartitioned stream
     */
    public Stream shuffle() {
        return partition(Grouping.shuffle(new NullStruct()));
    }

    /**
     * Repartitions with Trident's {@link GlobalGrouping}.
     *
     * @return the repartitioned stream
     */
    public Stream global() {
        // Trident's own GlobalGrouping is used instead of Storm's built-in global grouping so
        // that a single-emit batch-to-partition mapping can be specified without depending on
        // Storm's internals.
        return partition(new GlobalGrouping());
    }

    /**
     * Repartitions by hashing the tuple value at index 0, which is the batch id.
     *
     * @return the repartitioned stream
     */
    public Stream batchGlobal() {
        // the first field is the batch id
        return partition(new IndexHashGrouping(0));
    }

    /**
     * Repartitions using Storm's "all" grouping.
     *
     * @return the repartitioned stream
     */
    public Stream broadcast() {
        return partition(Grouping.all(new NullStruct()));
    }

    /**
     * Repartitions with Trident's {@link IdentityGrouping}.
     *
     * @return the repartitioned stream
     */
    public Stream identityPartition() {
        return partition(new IdentityGrouping());
    }

    /**
     * Repartitions this stream with {@code grouping}.
     *
     * <p>If the current node is already a {@link PartitionNode}, a pass-through
     * {@link TrueFilter} node is inserted first and the grouping is applied after it.
     * Otherwise a new {@link PartitionNode} is added that reuses the current node's stream id
     * and output fields.
     *
     * @param grouping the Storm grouping to apply
     * @return the repartitioned stream
     */
    public Stream partition(Grouping grouping) {
        if (_node instanceof PartitionNode) {
            // NOTE(review): presumably two partition nodes must not be adjacent, hence the
            // no-op filter in between — confirm against the planner.
            return each(new Fields(), new TrueFilter()).partition(grouping);
        }
        return _topology.addSourcedNode(this, new PartitionNode(_node.streamId, _name, getOutputFields(),
            grouping));
    }

    /**
     * Applies {@code assembly} to this stream.
     *
     * @param assembly the assembly to apply
     * @return whatever {@code assembly.apply(this)} returns
     */
    public Stream applyAssembly(Assembly assembly) {
        return assembly.apply(this);
    }

    /**
     * Runs {@code function} on each tuple. The resulting output fields are this stream's
     * output fields followed by {@code functionFields}.
     *
     * @param inputFields fields passed to the function, or {@code null} to skip validation
     * @param function the function to apply
     * @param functionFields the fields the function appends
     * @return the resulting stream
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    @Override
    public Stream each(Fields inputFields, Function function, Fields functionFields) {
        projectionValidation(inputFields);
        return _topology.addSourcedNode(this, new ProcessorNode(_topology.getUniqueStreamId(), _name,
            TridentUtils.fieldsConcat(getOutputFields(), functionFields), functionFields, new EachProcessor(
                inputFields, function)));
    }

    /**
     * Runs {@code agg} on each partition. This creates brand-new tuples: the output fields are
     * exactly {@code functionFields}, and the original fields are not carried over.
     *
     * @param inputFields fields fed to the aggregator, or {@code null} to skip validation
     * @param agg the aggregator
     * @param functionFields the aggregator's output fields
     * @return the aggregated stream
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    @Override
    public Stream partitionAggregate(Fields inputFields, Aggregator agg, Fields functionFields) {
        projectionValidation(inputFields);
        return _topology.addSourcedNode(this, new ProcessorNode(_topology.getUniqueStreamId(), _name,
            functionFields, functionFields, new AggregateProcessor(inputFields, agg)));
    }

    /**
     * Queries {@code state} for each tuple. The resulting output fields are this stream's
     * output fields followed by {@code functionFields}.
     *
     * <p>The new node is also registered in {@code _topology._colocate} under the state's id.
     * NOTE(review): this assumes {@code state} belongs to this topology; otherwise
     * {@code _colocate.get(stateId)} presumably returns {@code null} and this throws
     * {@link NullPointerException}.
     *
     * @param state the state to query
     * @param inputFields fields passed to the query, or {@code null} to skip validation
     * @param function the query function
     * @param functionFields the fields the query appends
     * @return the resulting stream
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    public Stream stateQuery(TridentState state, Fields inputFields, QueryFunction function,
            Fields functionFields) {
        projectionValidation(inputFields);
        String stateId = state._node.stateInfo.id;
        Node n =
                new ProcessorNode(_topology.getUniqueStreamId(), _name, TridentUtils.fieldsConcat(
                    getOutputFields(), functionFields), functionFields, new StateQueryProcessor(stateId,
                    inputFields, function));
        _topology._colocate.get(stateId).add(n);
        return _topology.addSourcedNode(this, n);
    }

    /**
     * Equivalent to {@link #partitionPersist(StateSpec, Fields, StateUpdater, Fields)} with
     * {@code new StateSpec(stateFactory)}.
     */
    public TridentState partitionPersist(StateFactory stateFactory, Fields inputFields, StateUpdater updater,
            Fields functionFields) {
        return partitionPersist(new StateSpec(stateFactory), inputFields, updater, functionFields);
    }

    /**
     * Persists each partition with {@code updater}.
     *
     * <p>A fresh state id is allocated from the topology. The new node is marked as a
     * committer, carries a {@link NodeStateInfo} built from that id and {@code stateSpec}, and
     * is added with {@code addSourcedStateNode}. Its output fields are {@code functionFields}.
     *
     * @param stateSpec specification of the state to persist into
     * @param inputFields fields passed to the updater, or {@code null} to skip validation
     * @param updater the state updater
     * @param functionFields the output fields of the persistence node
     * @return the resulting state handle
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    public TridentState partitionPersist(StateSpec stateSpec, Fields inputFields, StateUpdater updater,
            Fields functionFields) {
        projectionValidation(inputFields);
        String id = _topology.getUniqueStateId();
        ProcessorNode n =
                new ProcessorNode(_topology.getUniqueStreamId(), _name, functionFields, functionFields,
                    new PartitionPersistProcessor(id, inputFields, updater));
        n.committer = true;
        n.stateInfo = new NodeStateInfo(id, stateSpec);
        return _topology.addSourcedStateNode(this, n);
    }

    /** Equivalent to the four-argument overload with empty {@code functionFields}. */
    public TridentState partitionPersist(StateFactory stateFactory, Fields inputFields, StateUpdater updater) {
        return partitionPersist(stateFactory, inputFields, updater, new Fields());
    }

    /** Equivalent to the four-argument overload with empty {@code functionFields}. */
    public TridentState partitionPersist(StateSpec stateSpec, Fields inputFields, StateUpdater updater) {
        return partitionPersist(stateSpec, inputFields, updater, new Fields());
    }

    /**
     * Equivalent to {@link #each(Fields, Function, Fields)} with {@code null} input fields.
     * NOTE(review): how {@code EachProcessor} treats {@code null} input fields is not visible
     * here.
     */
    public Stream each(Function function, Fields functionFields) {
        return each(null, function, functionFields);
    }

    /**
     * Filters tuples with {@code filter}. The filter is wrapped in a {@link FilterExecutor}
     * function that appends no fields, so the output fields are unchanged.
     *
     * @param inputFields fields passed to the filter, or {@code null} to skip validation
     * @param filter the filter to apply
     * @return the filtered stream
     */
    public Stream each(Fields inputFields, Filter filter) {
        return each(inputFields, new FilterExecutor(filter), new Fields());
    }

    /**
     * Starts a chained aggregation whose global aggregations repartition with
     * {@link #batchGlobal()}.
     *
     * @return a declarer for chaining aggregators
     */
    public ChainedAggregatorDeclarer chainedAgg() {
        return new ChainedAggregatorDeclarer(this, new BatchGlobalAggScheme());
    }

    /** Equivalent to {@link #partitionAggregate(Fields, Aggregator, Fields)} with {@code null} input fields. */
    public Stream partitionAggregate(Aggregator agg, Fields functionFields) {
        return partitionAggregate(null, agg, functionFields);
    }

    /** Equivalent to the {@code CombinerAggregator} overload with {@code null} input fields. */
    public Stream partitionAggregate(CombinerAggregator agg, Fields functionFields) {
        return partitionAggregate(null, agg, functionFields);
    }

    /**
     * Aggregates each partition with a {@link CombinerAggregator} via a single-step
     * {@link #chainedAgg()} chain.
     *
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    public Stream partitionAggregate(Fields inputFields, CombinerAggregator agg, Fields functionFields) {
        projectionValidation(inputFields);
        return chainedAgg().partitionAggregate(inputFields, agg, functionFields).chainEnd();
    }

    /** Equivalent to the {@code ReducerAggregator} overload with {@code null} input fields. */
    public Stream partitionAggregate(ReducerAggregator agg, Fields functionFields) {
        return partitionAggregate(null, agg, functionFields);
    }

    /**
     * Aggregates each partition with a {@link ReducerAggregator} via a single-step
     * {@link #chainedAgg()} chain.
     *
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    public Stream partitionAggregate(Fields inputFields, ReducerAggregator agg, Fields functionFields) {
        projectionValidation(inputFields);
        return chainedAgg().partitionAggregate(inputFields, agg, functionFields).chainEnd();
    }

    /** Equivalent to {@link #aggregate(Fields, Aggregator, Fields)} with {@code null} input fields. */
    public Stream aggregate(Aggregator agg, Fields functionFields) {
        return aggregate(null, agg, functionFields);
    }

    /**
     * Aggregates with an {@link Aggregator} via a single-step {@link #chainedAgg()} chain.
     *
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    public Stream aggregate(Fields inputFields, Aggregator agg, Fields functionFields) {
        projectionValidation(inputFields);
        return chainedAgg().aggregate(inputFields, agg, functionFields).chainEnd();
    }

    /** Equivalent to the {@code CombinerAggregator} overload with {@code null} input fields. */
    public Stream aggregate(CombinerAggregator agg, Fields functionFields) {
        return aggregate(null, agg, functionFields);
    }

    /**
     * Aggregates with a {@link CombinerAggregator} via a single-step {@link #chainedAgg()}
     * chain.
     *
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    public Stream aggregate(Fields inputFields, CombinerAggregator agg, Fields functionFields) {
        projectionValidation(inputFields);
        return chainedAgg().aggregate(inputFields, agg, functionFields).chainEnd();
    }

    /** Equivalent to the {@code ReducerAggregator} overload with {@code null} input fields. */
    public Stream aggregate(ReducerAggregator agg, Fields functionFields) {
        return aggregate(null, agg, functionFields);
    }

    /**
     * Aggregates with a {@link ReducerAggregator} via a single-step {@link #chainedAgg()}
     * chain.
     *
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    public Stream aggregate(Fields inputFields, ReducerAggregator agg, Fields functionFields) {
        projectionValidation(inputFields);
        return chainedAgg().aggregate(inputFields, agg, functionFields).chainEnd();
    }

    /** Equivalent to {@link #partitionPersist(StateSpec, StateUpdater, Fields)} with {@code new StateSpec(stateFactory)}. */
    public TridentState partitionPersist(StateFactory stateFactory, StateUpdater updater,
            Fields functionFields) {
        return partitionPersist(new StateSpec(stateFactory), updater, functionFields);
    }

    /** Equivalent to the four-argument overload with {@code null} input fields. */
    public TridentState partitionPersist(StateSpec stateSpec, StateUpdater updater, Fields functionFields) {
        return partitionPersist(stateSpec, null, updater, functionFields);
    }

    /** Equivalent to the three-argument overload with empty {@code functionFields}. */
    public TridentState partitionPersist(StateFactory stateFactory, StateUpdater updater) {
        return partitionPersist(stateFactory, updater, new Fields());
    }

    /** Equivalent to the three-argument overload with empty {@code functionFields}. */
    public TridentState partitionPersist(StateSpec stateSpec, StateUpdater updater) {
        return partitionPersist(stateSpec, updater, new Fields());
    }

    /** Equivalent to the {@code StateSpec} overload with {@code new StateSpec(stateFactory)}. */
    public TridentState persistentAggregate(StateFactory stateFactory, CombinerAggregator agg,
            Fields functionFields) {
        return persistentAggregate(new StateSpec(stateFactory), agg, functionFields);
    }

    /** Equivalent to the four-argument overload with {@code null} input fields. */
    public TridentState persistentAggregate(StateSpec spec, CombinerAggregator agg, Fields functionFields) {
        return persistentAggregate(spec, null, agg, functionFields);
    }

    /** Equivalent to the {@code StateSpec} overload with {@code new StateSpec(stateFactory)}. */
    public TridentState persistentAggregate(StateFactory stateFactory, Fields inputFields,
            CombinerAggregator agg, Fields functionFields) {
        return persistentAggregate(new StateSpec(stateFactory), inputFields, agg, functionFields);
    }

    /**
     * Aggregates with {@code agg} using the global aggregation scheme, then persists the
     * aggregated {@code functionFields} with a {@link CombinerAggStateUpdater}.
     *
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    public TridentState persistentAggregate(StateSpec spec, Fields inputFields, CombinerAggregator agg,
            Fields functionFields) {
        projectionValidation(inputFields);
        // Uses a global grouping instead of the normal batch-global aggregation because the
        // result needs to be consistent across batches.
        return new ChainedAggregatorDeclarer(this, new GlobalAggScheme())
            .aggregate(inputFields, agg, functionFields).chainEnd()
            .partitionPersist(spec, functionFields, new CombinerAggStateUpdater(agg), functionFields);
    }

    /** Equivalent to the {@code StateSpec} overload with {@code new StateSpec(stateFactory)}. */
    public TridentState persistentAggregate(StateFactory stateFactory, ReducerAggregator agg,
            Fields functionFields) {
        return persistentAggregate(new StateSpec(stateFactory), agg, functionFields);
    }

    /** Equivalent to the four-argument overload with {@code null} input fields. */
    public TridentState persistentAggregate(StateSpec spec, ReducerAggregator agg, Fields functionFields) {
        return persistentAggregate(spec, null, agg, functionFields);
    }

    /** Equivalent to the {@code StateSpec} overload with {@code new StateSpec(stateFactory)}. */
    public TridentState persistentAggregate(StateFactory stateFactory, Fields inputFields,
            ReducerAggregator agg, Fields functionFields) {
        return persistentAggregate(new StateSpec(stateFactory), inputFields, agg, functionFields);
    }

    /**
     * Repartitions with {@link #global()}, then persists with a {@link ReducerAggStateUpdater}.
     *
     * @throws IllegalArgumentException if an input field is not an output field of this stream
     */
    public TridentState persistentAggregate(StateSpec spec, Fields inputFields, ReducerAggregator agg,
            Fields functionFields) {
        projectionValidation(inputFields);
        return global().partitionPersist(spec, inputFields, new ReducerAggStateUpdater(agg), functionFields);
    }

    /** Equivalent to {@link #stateQuery(TridentState, Fields, QueryFunction, Fields)} with {@code null} input fields. */
    public Stream stateQuery(TridentState state, QueryFunction function, Fields functionFields) {
        return stateQuery(state, null, function, functionFields);
    }

    /**
     * Returns this stream.
     *
     * @return {@code this}
     */
    @Override
    public Stream toStream() {
        return this;
    }

    /**
     * Returns the output fields of the node behind this stream.
     *
     * @return {@code _node.allOutputFields}
     */
    @Override
    public Fields getOutputFields() {
        return _node.allOutputFields;
    }

    /** Aggregation scheme that repartitions with {@link Stream#batchGlobal()}. */
    static class BatchGlobalAggScheme implements GlobalAggregationScheme<Stream> {

        @Override
        public IAggregatableStream aggPartition(Stream s) {
            return s.batchGlobal();
        }

        @Override
        public BatchToPartition singleEmitPartitioner() {
            return new IndexHashBatchToPartition();
        }

    }

    /** Aggregation scheme that repartitions with {@link Stream#global()}. */
    static class GlobalAggScheme implements GlobalAggregationScheme<Stream> {

        @Override
        public IAggregatableStream aggPartition(Stream s) {
            return s.global();
        }

        @Override
        public BatchToPartition singleEmitPartitioner() {
            return new GlobalBatchToPartition();
        }

    }

    /**
     * Verifies that every field in {@code projFields} is an output field of this stream.
     *
     * @param projFields fields to check; {@code null} means "nothing to check"
     * @throws IllegalArgumentException naming the first field that is not present
     */
    private void projectionValidation(Fields projFields) {
        if (projFields == null) {
            return;
        }

        Fields allFields = this.getOutputFields();
        for (String field : projFields) {
            if (!allFields.contains(field)) {
                throw new IllegalArgumentException("Trying to select non-existent field: '" + field
                        + "' from stream containing fields: <" + allFields + ">");
            }
        }
    }
}
